D2GRs2 源码阅读3-DataSet
├── data
│ ├── dataset.py # 构建序列 dataset 时间逆序
│ ├── eval.py # 评估指标,并记录到TensorBoard
│ ├── item_features.py # 定义`ItemFeatures`数据类
│ ├── preprocessor.py # 预处理两种数据集(下载、处理等)
│ └── reco_dataset.py # 获取推荐(训练、评估)数据集
├── trainer
│ └── data_loader.py # 数据加载器、支持分布式训练
├── train.py # 训练脚本
# train.py
dataset = get_reco_dataset(
dataset_name=dataset_name, # ml-1m
max_sequence_length=max_sequence_length, # 200
chronological=True, # 按时间排序
positional_sampling_ratio=positional_sampling_ratio, # 1 按位置采样
)
train_data_sampler, train_data_loader = create_data_loader(
dataset.train_dataset,
batch_size=local_batch_size,
world_size=world_size,
rank=rank,
shuffle=True,
drop_last=world_size > 1,
)
- @dataclass
@dataclass装饰器,这个类能够存储一个数字,拥有比大小的功能,很大程度上减少了代码量,很方便。除了上面的整型外,还可以使用其他的类型,包括自己定义的数据类型。深度学习pytorch之dataclass
from dataclasses import dataclass
@dataclass
class RecoDataset:
max_sequence_length: int
num_unique_items: int
max_item_id: int
all_item_ids: List[int]
train_dataset: torch.utils.data.Dataset
eval_dataset: torch.utils.data.Dataset
- Dataset
Dataset是PyTorch提供的一个抽象类,我们可以继承这个类并重写
__getitem__
和__len__
方法,从而创建自己的数据集。__getitem__
方法用于获取单个数据样本,__len__
方法则返回数据集的大小。
from torch.utils.data import Dataset
import os
class CustomDataset(Dataset):
def __init__(self, data_dir, transform=None):
self.data_dir = data_dir
self.images = os.listdir(data_dir)
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
img_path = os.path.join(self.data_dir, self.images[idx])
image = Image.open(img_path)
label = idx # 这里为了简化,我们直接用索引作为标签
return image, label
- Dataloader
DataLoader是PyTorch提供的一个数据加载器,它可以从Dataset中读取数据,并以批次的形式提供给模型进行训练。DataLoader的主要参数包括:
- dataset:输入的数据集,必须是Dataset对象。
- batch_size:每个批次的数据量。
- shuffle:是否在每个epoch开始时打乱数据。
- num_workers:用于数据加载的子进程数。
from torch.utils.data import DataLoader
# 假设我们已经创建了一个CustomDataset对象
dataset = CustomDataset(data_dir='./data', transform=transform)
# 创建一个DataLoader对象
data_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
# 在训练循环中使用DataLoader
for epoch in range(num_epochs):
for batch_idx, (data, target) in enumerate (data_loader):
# 在这里进行模型的训练操作
pass
Tensorflow模型的格式 - chease - 博客园